{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MF7BncmmLBeO"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from sklearn.datasets import load_digits\n",
    "from sklearn import datasets\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**DISCLAIMER**\n",
    "\n",
    "The presented code is not optimized, it serves an educational purpose. It is written for CPU, it uses only fully-connected networks and an extremely simplistic dataset. However, it contains all components that can help to understand how a hierarchical Variational Auto-Encoder (VAE) works, and it should be rather easy to extend it to more sophisticated models. This code could be run almost on any laptop/PC, and it takes a couple of minutes top to get the result."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RKsmjLumL5A2"
   },
   "source": [
    "## Dataset: Digits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, we go wild and use a dataset that is simpler than MNIST! We use a scipy dataset called Digits. It consists of ~1500 images of size 8x8, and each pixel can take values in $\\{0, 1, \\ldots, 16\\}$.\n",
    "\n",
    "The goal of using this dataset is that everyone can run it on a laptop, without any gpu etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "hSWUnXAYLLif"
   },
   "outputs": [],
   "source": [
    "class Digits(Dataset):\n",
    "    \"\"\"Scikit-Learn Digits dataset.\"\"\"\n",
    "\n",
    "    def __init__(self, mode='train', transforms=None):\n",
    "        digits = load_digits()\n",
    "        if mode == 'train':\n",
    "            self.data = digits.data[:1000].astype(np.float32)\n",
    "        elif mode == 'val':\n",
    "            self.data = digits.data[1000:1350].astype(np.float32)\n",
    "        else:\n",
    "            self.data = digits.data[1350:].astype(np.float32)\n",
    "\n",
    "        self.transforms = transforms\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        sample = self.data[idx]\n",
    "        if self.transforms:\n",
    "            sample = self.transforms(sample)\n",
    "        return sample"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xQyrkrqAL7p8"
   },
   "source": [
    "## Auxiliary functions and classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hendrycks, D., & Gimpel, K. (2016). Gaussian error linear units (gelus). arXiv preprint arXiv:1606.08415.\n",
    "class GELU(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return x * torch.sigmoid(1.702 * x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tw00sH-6L9yg"
   },
   "source": [
    "### Distributions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AJh8NiXxLNf9"
   },
   "outputs": [],
   "source": [
    "PI = torch.from_numpy(np.asarray(np.pi))\n",
    "EPS = 1.e-7\n",
    "\n",
    "def log_categorical(x, p, num_classes=256, reduction=None, dim=None):\n",
    "    x_one_hot = F.one_hot(x.long(), num_classes=num_classes)\n",
    "    log_p = x_one_hot * torch.log(torch.clamp(p, EPS, 1. - EPS))\n",
    "    if reduction == 'avg':\n",
    "        return torch.mean(log_p, dim)\n",
    "    elif reduction == 'sum':\n",
    "        return torch.sum(log_p, dim)\n",
    "    else:\n",
    "        return log_p\n",
    "\n",
    "def log_bernoulli(x, p, reduction=None, dim=None):\n",
    "    pp = torch.clamp(p, EPS, 1. - EPS)\n",
    "    log_p = x * torch.log(pp) + (1. - x) * torch.log(1. - pp)\n",
    "    if reduction == 'avg':\n",
    "        return torch.mean(log_p, dim)\n",
    "    elif reduction == 'sum':\n",
    "        return torch.sum(log_p, dim)\n",
    "    else:\n",
    "        return log_p\n",
    "\n",
    "def log_normal_diag(x, mu, log_var, reduction=None, dim=None):\n",
    "    log_p = -0.5 * torch.log(2. * PI) - 0.5 * log_var - 0.5 * torch.exp(-log_var) * (x - mu)**2.\n",
    "    if reduction == 'avg':\n",
    "        return torch.mean(log_p, dim)\n",
    "    elif reduction == 'sum':\n",
    "        return torch.sum(log_p, dim)\n",
    "    else:\n",
    "        return log_p\n",
    "\n",
    "def log_standard_normal(x, reduction=None, dim=None):\n",
    "    log_p = -0.5 * torch.log(2. * PI) - 0.5 * x**2.\n",
    "    if reduction == 'avg':\n",
    "        return torch.mean(log_p, dim)\n",
    "    elif reduction == 'sum':\n",
    "        return torch.sum(log_p, dim)\n",
    "    else:\n",
    "        return log_p\n",
    "\n",
    "# Chakraborty & Chakravarty, \"A new discrete probability distribution with integer support on (−∞, ∞)\",\n",
    "#  Communications in Statistics - Theory and Methods, 45:2, 492-505, DOI: 10.1080/03610926.2013.830743\n",
    "\n",
    "def log_min_exp(a, b, epsilon=1e-8):\n",
    "    \"\"\"\n",
    "    Source: https://github.com/jornpeters/integer_discrete_flows\n",
    "    Computes the log of exp(a) - exp(b) in a (more) numerically stable fashion.\n",
    "    Using:\n",
    "    log(exp(a) - exp(b))\n",
    "    c + log(exp(a-c) - exp(b-c))\n",
    "    a + log(1 - exp(b-a))\n",
    "    And note that we assume b < a always.\n",
    "    \"\"\"\n",
    "    y = a + torch.log(1 - torch.exp(b - a) + epsilon)\n",
    "\n",
    "    return y\n",
    "\n",
    "def log_integer_probability(x, mean, logscale):\n",
    "    scale = torch.exp(logscale)\n",
    "\n",
    "    logp = log_min_exp(\n",
    "      F.logsigmoid((x + 0.5 - mean) / scale),\n",
    "      F.logsigmoid((x - 0.5 - mean) / scale))\n",
    "\n",
    "    return logp\n",
    "\n",
    "def log_integer_probability_standard(x):\n",
    "    logp = log_min_exp(\n",
    "      F.logsigmoid(x + 0.5),\n",
    "      F.logsigmoid(x - 0.5))\n",
    "\n",
    "    return logp"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qSP2qiMqMICK"
   },
   "source": [
    "## Hierarchical Variational Auto-Encoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**NOTE** Please note that we use a single class unlike in previous implementations of VAEs. We do it for clarity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GRYA6JA4LWEC"
   },
   "outputs": [],
   "source": [
    "class HierarchicalVAE(nn.Module):\n",
    "    def __init__(self, nn_r_1, nn_r_2, nn_delta_1, nn_delta_2, nn_z_1, nn_x, num_vals=256, D=64, L=16, likelihood_type='categorical'):\n",
    "        super(HierarchicalVAE, self).__init__()\n",
    "\n",
    "        print('Hierachical VAE by JT.')\n",
    "        \n",
    "        # bottom-up path\n",
    "        self.nn_r_1 = nn_r_1\n",
    "        self.nn_r_2 = nn_r_2\n",
    "        \n",
    "        self.nn_delta_1 = nn_delta_1\n",
    "        self.nn_delta_2 = nn_delta_2\n",
    "        \n",
    "        # top-down path\n",
    "        self.nn_z_1 = nn_z_1\n",
    "        self.nn_x = nn_x\n",
    "\n",
    "        \n",
    "        # other params\n",
    "        self.D = D\n",
    "        \n",
    "        self.L = L\n",
    "        \n",
    "        self.num_vals = num_vals\n",
    "        \n",
    "        self.likelihood_type = likelihood_type\n",
    "\n",
    "    def reparameterization(self, mu, log_var):\n",
    "        std = torch.exp(0.5*log_var)\n",
    "        eps = torch.randn_like(std)\n",
    "        return mu + std * eps\n",
    "        \n",
    "    def forward(self, x, reduction='avg'):\n",
    "        #=====\n",
    "        # bottom-up\n",
    "        # step 1\n",
    "        r_1 = self.nn_r_1(x)\n",
    "        r_2 = self.nn_r_2(r_1)\n",
    "        \n",
    "        #step 2\n",
    "        delta_1 = self.nn_delta_1(r_1)\n",
    "        delta_mu_1, delta_log_var_1 = torch.chunk(delta_1, 2, dim=1)\n",
    "        delta_log_var_1 = F.hardtanh(delta_log_var_1, -7., 2.)\n",
    "        \n",
    "        # step 3\n",
    "        delta_2 = self.nn_delta_2(r_2)\n",
    "        delta_mu_2, delta_log_var_2 = torch.chunk(delta_2, 2, dim=1)\n",
    "        delta_log_var_2 = F.hardtanh(delta_log_var_2, -7., 2.)\n",
    "        \n",
    "        # top-down\n",
    "        # step 4\n",
    "        z_2 = self.reparameterization(delta_mu_2, delta_log_var_2)\n",
    "        \n",
    "        # step 5\n",
    "        h_1 = self.nn_z_1(z_2)\n",
    "        mu_1, log_var_1 = torch.chunk(h_1, 2, dim=1)\n",
    "        \n",
    "        # step 6\n",
    "        z_1 = self.reparameterization(mu_1 + delta_mu_1, log_var_1 + delta_log_var_1)\n",
    "        \n",
    "        # step 7\n",
    "        h_d = self.nn_x(z_1)\n",
    "\n",
    "        if self.likelihood_type == 'categorical':\n",
    "            b = h_d.shape[0]\n",
    "            d = h_d.shape[1]//self.num_vals\n",
    "            h_d = h_d.view(b, d, self.num_vals)\n",
    "            mu_d = torch.softmax(h_d, 2)\n",
    "\n",
    "        elif self.likelihood_type == 'bernoulli':\n",
    "            mu_d = torch.sigmoid(h_d)\n",
    "        \n",
    "        #=====ELBO\n",
    "        # RE\n",
    "        if self.likelihood_type == 'categorical':\n",
    "            RE = log_categorical(x, mu_d, num_classes=self.num_vals, reduction='sum', dim=-1).sum(-1)\n",
    "\n",
    "        elif self.likelihood_type == 'bernoulli':\n",
    "            RE = log_bernoulli(x, mu_d, reduction='sum', dim=-1)\n",
    "        \n",
    "        # KL\n",
    "        KL_z_2 = 0.5 * (delta_mu_2**2 + torch.exp(delta_log_var_2) - delta_log_var_2 - 1).sum(-1)\n",
    "        KL_z_1 = 0.5 * (delta_mu_1**2 / torch.exp(log_var_1) + torch.exp(delta_log_var_1) -\\\n",
    "                        delta_log_var_1 - 1).sum(-1)\n",
    "        \n",
    "        KL = KL_z_1 + KL_z_2\n",
    "        \n",
    "        error = 0\n",
    "        if np.isnan(RE.detach().numpy()).any():\n",
    "            print('RE {}'.format(RE))\n",
    "            print('KL {}'.format(KL))\n",
    "            error = 1\n",
    "        if np.isnan(KL.detach().numpy()).any():\n",
    "            print('RE {}'.format(RE))            \n",
    "            print('KL {}'.format(KL))\n",
    "            error = 1\n",
    "\n",
    "        if error == 1:\n",
    "            raise ValueError()\n",
    "        \n",
    "        # Final ELBO\n",
    "        if reduction == 'sum':\n",
    "            loss = -(RE - KL).sum()\n",
    "        else:\n",
    "            loss = -(RE - KL).mean()\n",
    "        \n",
    "        return loss\n",
    "\n",
    "    def sample(self, batch_size=64):\n",
    "        # step 1\n",
    "        z_2 = torch.randn(batch_size, self.L)\n",
    "        # step 2\n",
    "        h_1 = self.nn_z_1(z_2)\n",
    "        mu_1, log_var_1 = torch.chunk(h_1, 2, dim=1)\n",
    "        # step 3\n",
    "        z_1 = self.reparameterization(mu_1, log_var_1)\n",
    "        \n",
    "        # step 4\n",
    "        h_d = self.nn_x(z_1)\n",
    "        \n",
    "        if self.likelihood_type == 'categorical':\n",
    "            b = batch_size\n",
    "            d = h_d.shape[1]//self.num_vals\n",
    "            h_d = h_d.view(b, d, self.num_vals)\n",
    "            mu_d = torch.softmax(h_d, 2)\n",
    "            # step 5\n",
    "            p = mu_d.view(-1, self.num_vals)\n",
    "            x_new = torch.multinomial(p, num_samples=1).view(b, d)\n",
    "\n",
    "        elif self.likelihood_type == 'bernoulli':\n",
    "            mu_d = torch.sigmoid(h_d)\n",
    "            # step 5\n",
    "            x_new = torch.bernoulli(mu_d)\n",
    "        return x_new"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vUoPkTmrMVnx"
   },
   "source": [
    "## Evaluation and Training functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JvwmRoi7MVto"
   },
   "source": [
    "**Evaluation step, sampling and curve plotting**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JHx4RIqDLZe9"
   },
   "outputs": [],
   "source": [
    "def evaluation(test_loader, name=None, model_best=None, epoch=None):\n",
    "    # EVALUATION\n",
    "    if model_best is None:\n",
    "        # load best performing model\n",
    "        model_best = torch.load(name + '.model')\n",
    "\n",
    "    model_best.eval()\n",
    "    loss = 0.\n",
    "    N = 0.\n",
    "    for indx_batch, test_batch in enumerate(test_loader):\n",
    "        loss_t = model_best.forward(test_batch, reduction='sum')\n",
    "        loss = loss + loss_t.item()\n",
    "        N = N + test_batch.shape[0]\n",
    "    loss = loss / N\n",
    "\n",
    "    if epoch is None:\n",
    "        print(f'FINAL LOSS: nll={loss}')\n",
    "    else:\n",
    "        print(f'Epoch: {epoch}, val nll={loss}')\n",
    "\n",
    "    return loss\n",
    "\n",
    "\n",
    "def samples_real(name, test_loader):\n",
    "    # REAL-------\n",
    "    num_x = 4\n",
    "    num_y = 4\n",
    "    x = next(iter(test_loader)).detach().numpy()\n",
    "\n",
    "    fig, ax = plt.subplots(num_x, num_y)\n",
    "    for i, ax in enumerate(ax.flatten()):\n",
    "        plottable_image = np.reshape(x[i], (8, 8))\n",
    "        ax.imshow(plottable_image, cmap='gray')\n",
    "        ax.axis('off')\n",
    "\n",
    "    plt.savefig(name+'_real_images.pdf', bbox_inches='tight')\n",
    "    plt.close()\n",
    "\n",
    "\n",
    "def samples_generated(name, data_loader, extra_name=''):\n",
    "    x = next(iter(data_loader)).detach().numpy()\n",
    "\n",
    "    # GENERATIONS-------\n",
    "    model_best = torch.load(name + '.model')\n",
    "    model_best.eval()\n",
    "\n",
    "    num_x = 4\n",
    "    num_y = 4\n",
    "    x = model_best.sample(batch_size=num_x * num_y)\n",
    "    x = x.detach().numpy()\n",
    "\n",
    "    fig, ax = plt.subplots(num_x, num_y)\n",
    "    for i, ax in enumerate(ax.flatten()):\n",
    "        plottable_image = np.reshape(x[i], (8, 8))\n",
    "        ax.imshow(plottable_image, cmap='gray')\n",
    "        ax.axis('off')\n",
    "\n",
    "    plt.savefig(name + '_generated_images' + extra_name + '.pdf', bbox_inches='tight')\n",
    "    plt.close()\n",
    "\n",
    "\n",
    "def plot_curve(name, nll_val):\n",
    "    plt.plot(np.arange(len(nll_val)), nll_val, linewidth='3')\n",
    "    plt.xlabel('epochs')\n",
    "    plt.ylabel('nll')\n",
    "    plt.savefig(name + '_nll_val_curve.pdf', bbox_inches='tight')\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "umU3VYKzMbDt"
   },
   "source": [
    "**Training step**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NxkUZ1xVLbm_"
   },
   "outputs": [],
   "source": [
    "def training(name, max_patience, num_epochs, model, optimizer, training_loader, val_loader):\n",
    "    nll_val = []\n",
    "    best_nll = 1000.\n",
    "    patience = 0\n",
    "\n",
    "    # Main loop\n",
    "    for e in range(num_epochs):\n",
    "        # TRAINING\n",
    "        model.train()\n",
    "        for indx_batch, batch in enumerate(training_loader):\n",
    "            loss = model.forward(batch)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward(retain_graph=True)\n",
    "            optimizer.step()\n",
    "\n",
    "        # Validation\n",
    "        loss_val = evaluation(val_loader, model_best=model, epoch=e)\n",
    "        nll_val.append(loss_val)  # save for plotting\n",
    "\n",
    "        if e == 0:\n",
    "            print('saved!')\n",
    "            torch.save(model, name + '.model')\n",
    "            best_nll = loss_val\n",
    "        else:\n",
    "            if loss_val < best_nll:\n",
    "                print('saved!')\n",
    "                torch.save(model, name + '.model')\n",
    "                best_nll = loss_val\n",
    "                patience = 0\n",
    "\n",
    "                samples_generated(name, val_loader, extra_name=\"_epoch_\" + str(e))\n",
    "            else:\n",
    "                patience = patience + 1\n",
    "\n",
    "        if patience > max_patience:\n",
    "            break\n",
    "\n",
    "    nll_val = np.asarray(nll_val)\n",
    "\n",
    "    return nll_val"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0BXJ9dN0MinB"
   },
   "source": [
    "## Experiments"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KsF7f-Q-MkWu"
   },
   "source": [
    "**Initialize datasets**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transforms = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fqZKMNM0LdQ1"
   },
   "outputs": [],
   "source": [
    "train_data = Digits(mode='train', transforms=transforms)\n",
    "val_data = Digits(mode='val', transforms=transforms)\n",
    "test_data = Digits(mode='test', transforms=transforms)\n",
    "\n",
    "training_loader = DataLoader(train_data, batch_size=64, shuffle=True)\n",
    "val_loader = DataLoader(val_data, batch_size=64, shuffle=False)\n",
    "test_loader = DataLoader(test_data, batch_size=64, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6lEKUznpMns7"
   },
   "source": [
    "**Hyperparameters**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ANQo7LrGLjIN"
   },
   "outputs": [],
   "source": [
    "D = 64   # input dimension\n",
    "L = 8    # number of latents\n",
    "M = 256  # the number of neurons in scale (s) and translation (t) nets\n",
    "\n",
    "lr = 1e-3 # learning rate\n",
    "num_epochs = 1000 # max. number of epochs\n",
    "max_patience = 20 # an early stopping is used, if training doesn't improve for longer than 20 epochs, it is stopped"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-7APXeunMrDh"
   },
   "source": [
    "**Creating a folder for results**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bjSUn1eWLkWm"
   },
   "outputs": [],
   "source": [
    "name = 'vae_hierarchical' + '_' + str(L)\n",
    "result_dir ='results/' + name + '/'\n",
    "if not(os.path.exists(result_dir)):\n",
    "    os.mkdir(result_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Hpwm6LWUMulQ"
   },
   "source": [
    "**Initializing the model: (i) determining the conditional likelihood distribution, (ii) defininig encoder and decoder nets, and a prior**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "FrnNsCqQLmK3",
    "outputId": "5f0cf2b1-0a96-4f5c-da9e-f78f909a5259"
   },
   "outputs": [],
   "source": [
    "likelihood_type = 'categorical'\n",
    "\n",
    "if likelihood_type == 'categorical':\n",
    "    num_vals = 17\n",
    "elif likelihood_type == 'bernoulli':\n",
    "    num_vals = 1\n",
    "\n",
    "\n",
    "nn_r_1 = nn.Sequential(nn.Linear(D, M), GELU(),\n",
    "                           nn.BatchNorm1d(M),\n",
    "                       nn.Linear(M, M), nn.GELU()\n",
    "                      )\n",
    "\n",
    "nn_r_2 = nn.Sequential(nn.Linear(M, M), GELU(),\n",
    "                           nn.BatchNorm1d(M),\n",
    "                       nn.Linear(M, M), nn.LeakyReLU()\n",
    "                      )\n",
    "\n",
    "nn_delta_1 = nn.Sequential(nn.Linear(M, M), GELU(),\n",
    "                           nn.BatchNorm1d(M),\n",
    "                           nn.Linear(M, 2 * (L * 2)),\n",
    "                           )\n",
    "\n",
    "nn_delta_2 = nn.Sequential(nn.Linear(M, M), GELU(),\n",
    "                           nn.BatchNorm1d(M),\n",
    "                           nn.Linear(M, 2 * L),\n",
    "                           )\n",
    "\n",
    "nn_z_1 = nn.Sequential(nn.Linear(L, M), GELU(),\n",
    "                       nn.BatchNorm1d(M),\n",
    "                       nn.Linear(M, 2 * (L * 2))\n",
    "                      )\n",
    "\n",
    "nn_x = nn.Sequential(nn.Linear(L * 2, M), GELU(),\n",
    "                     nn.BatchNorm1d(M),\n",
    "                     nn.Linear(M,M), GELU(),\n",
    "                     nn.BatchNorm1d(M),\n",
    "                     nn.Linear(M, D * num_vals)\n",
    "                    )\n",
    "\n",
    "\n",
    "# Eventually, we initialize the full model\n",
    "model = HierarchicalVAE(nn_r_1, nn_r_2, nn_delta_1, nn_delta_2, nn_z_1, nn_x, num_vals=num_vals, D=D, L=L, likelihood_type=likelihood_type)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3SzTemY3NSxO"
   },
   "source": [
    "**Optimizer - here we use Adamax**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "R9TZtLVtLoWc"
   },
   "outputs": [],
   "source": [
    "# OPTIMIZER\n",
    "optimizer = torch.optim.Adamax([p for p in model.parameters() if p.requires_grad == True], lr=lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dNf__W_ONVHA"
   },
   "source": [
    "**Training loop**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "KhqHgluGLqIC",
    "outputId": "c52fa1e4-3376-4bff-9f87-6f03613c4e42"
   },
   "outputs": [],
   "source": [
    "# Training procedure\n",
    "nll_val = training(name=result_dir + name, max_patience=max_patience, num_epochs=num_epochs, model=model, optimizer=optimizer,\n",
    "                       training_loader=training_loader, val_loader=val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-3XTxgEcNXfp"
   },
   "source": [
    "**The final evaluation**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "okK1mV_-LrRU",
    "outputId": "4664693f-742d-4453-94cf-d051d2efa9be"
   },
   "outputs": [],
   "source": [
    "test_loss = evaluation(name=result_dir + name, test_loader=test_loader)\n",
    "f = open(result_dir + name + '_test_loss.txt', \"w\")\n",
    "f.write(str(test_loss))\n",
    "f.close()\n",
    "\n",
    "samples_real(result_dir + name, test_loader)\n",
    "\n",
    "plot_curve(result_dir + name, nll_val)\n",
    "\n",
    "samples_generated(result_dir + name, test_loader, extra_name='FINAL')"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "vae_priors.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
